from spikingjelly.activation_based import monitor
from spikingjelly.clock_driven.neuron import LIFNode
import spikingjelly.activation_based as snn
from torch.utils.data import DataLoader
import torch.utils.data as data
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import pandas as pd
import os
from models import optimal_model
import argparse
from utils import Bar, Logger, AverageMeter, accuracy
import torchvision.datasets as datasets
from spikingjelly.clock_driven import functional
from utils.cifar10_dvs import CIFAR10DVS, ToPILImage, Resize, ToTensor
import math
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

'''

        This script evaluates the test accuracy of the final (trained) model and records the average spike (firing) rate of each spiking layer.


'''


def _parse_optimal_neuron(run_name):
    """Return the trailing 0/1 neuron-choice code of a checkpoint run name.

    The final ``_``-separated token of ``run_name`` is returned when it is
    non-empty and consists only of ``0``/``1`` characters; otherwise ``''``.
    E.g. ``'..._CosALR_300_11000001'`` -> ``'11000001'``.
    (Assumption, confirm against the training script: each digit selects the
    neuron type of one layer, as expected by the model's ``optimal_neuron``.)
    """
    suffix = run_name.rsplit('_', 1)[-1]
    return suffix if suffix and set(suffix) <= {'0', '1'} else ''


def _build_test_loader(args, data_dir, T):
    """Build the test DataLoader selected by ``args.dataset``.

    Returns ``(test_data_loader, c_in, num_classes)``.
    Raises NotImplementedError for an unknown dataset name.
    """
    if args.dataset in ('cifar10', 'cifar100'):
        c_in = 3
        if args.dataset == 'cifar10':
            dataset_cls = datasets.CIFAR10
            num_classes = 10
            normalization_mean = (0.4914, 0.4822, 0.4465)
            normalization_std = (0.2023, 0.1994, 0.2010)
        else:
            dataset_cls = datasets.CIFAR100
            num_classes = 100
            normalization_mean = (0.5071, 0.4867, 0.4408)
            normalization_std = (0.2675, 0.2565, 0.2761)
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(normalization_mean, normalization_std),
        ])
        testset = dataset_cls(root=data_dir, train=False, download=False, transform=transform_test)
        loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j)
    elif args.dataset == 'dvscifar10':
        c_in = 2
        num_classes = 10
        data_dir = os.path.join(data_dir, 'dvscifar10')
        transform_test = transforms.Compose([
            ToPILImage(),
            Resize(48),
            ToTensor(),
        ])
        testset = CIFAR10DVS(data_dir, train=False, use_frame=True, frames_num=T, split_by='number',
                             normalization=None, transform=transform_test)
        loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j)
    elif args.dataset == 'dvsgesture':
        c_in = 2
        num_classes = 11
        data_dir = os.path.join(data_dir, 'dvsgesture')
        testset = DVS128Gesture(root=data_dir, train=False, data_type='frame', frames_number=T, split_by='number')
        loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j,
                                 drop_last=False, pin_memory=True)
    elif args.dataset == 'tinyimagenet':
        c_in = 3
        num_classes = 200
        testdir = os.path.join(data_dir, 'tiny-imagenet-200', 'val')
        normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
                                         std=[0.2770, 0.2691, 0.2821])
        transform_test = transforms.Compose([
            transforms.Resize(64),
            transforms.ToTensor(),
            normalize,
        ])
        test_dataset = datasets.ImageFolder(testdir, transform=transform_test)
        loader = DataLoader(test_dataset, batch_size=args.b, shuffle=False, num_workers=args.j, pin_memory=True)
    else:
        raise NotImplementedError(args.dataset)
    return loader, c_in, num_classes


def _spike_rates(records):
    """Return ``sum / numel`` for every recorded tensor, in record order.

    For binary (0/1) spike outputs this is the fraction of neurons that fired.
    """
    return [tensor.sum().item() / tensor.numel() for tensor in records]


def main():
    """Evaluate a trained checkpoint and report per-layer spike (firing) rates.

    Prints the test accuracy, the list of per-layer firing rates (averaged over
    every test sample and every time step) and their mean. Returns the same
    values in a dict so they can be inspected (e.g. from a debugger).

    Raises:
        NotImplementedError: unknown ``-dataset``.
        RuntimeError: the test set is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-dataset', default='dvsgesture', type=str,
                        help='name of the dataset, cifar10, cifar100, dvscifar10, dvsgesture, or tinyimagenet')
    parser.add_argument('-data_dir', type=str, default='./data', help='directory of the dataset')
    parser.add_argument('-b', default=8, type=int, help='batchsize')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of loading workers')
    parser.add_argument('-checkpoint', default=None, type=str,
                        help='checkpoint path; defaults to logs/<run name>/checkpoint_max.pth')
    parser.add_argument('-optimal_neuron', default=None, type=str,
                        help='per-layer neuron code such as 11000001; defaults to the trailing code of the run name')
    args = parser.parse_args()

    # Run name of the checkpoint to evaluate (normally the only value to change).
    run_name = 'Optimal_dvsgesture_spiking_vgg11_T8_tau1.1_e300_bs8_SGD_lr0.05_wd0.0005_drop0.4_losslamb0.05_CosALR_300_11000001'
    file_path = args.checkpoint or os.path.join('logs', run_name, 'checkpoint_max.pth')
    weights = torch.load(file_path, map_location='cpu')

    # Hyperparameters that must match the run being loaded.
    arch = 'spiking_vgg11'
    # The neuron code must match the one the checkpoint was trained with; it is
    # taken from the run name unless given explicitly.
    neuron_code = args.optimal_neuron if args.optimal_neuron is not None else _parse_optimal_neuron(run_name)
    optimal_neuron = torch.tensor([int(char) for char in neuron_code])
    T = 8
    tau = 1.1
    # The run used drop0.4; 0.0 here presumably has no effect in eval mode (confirm).
    drop_rate = 0.0

    test_data_loader, c_in, num_classes = _build_test_loader(args, args.data_dir, T)

    net = optimal_model.__dict__[arch](optimal_neuron=optimal_neuron, num_classes=num_classes,
                                       neuron_dropout=drop_rate, tau=tau, c_in=c_in)
    net.load_state_dict(weights['net'])
    # NOTE(review): the monitor only sees spikingjelly.clock_driven.neuron.LIFNode
    # instances; if the model uses another neuron class nothing is recorded.
    mtor = monitor.OutputMonitor(net, instance=LIFNode)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()

    num_batches = len(test_data_loader)
    correct = 0
    test_samples = 0
    rate_sums = None  # per layer: sum over batches of (batch rate summed over T) * batch size
    with torch.no_grad():
        for batch_idx, (frame, label) in enumerate(test_data_loader, start=1):
            print('batch_idx:', batch_idx, ' of ', num_batches)
            frame = frame.float().to(device)
            if args.dataset == 'dvsgesture':
                # (B, T, ...) -> (T, B, ...) so that frame[t] is one time step.
                frame = frame.transpose(0, 1)
            # NOTE(review): dvscifar10 is indexed by frame[t] without a transpose,
            # as in the original script; confirm its batch layout is time-first.
            per_step = args.dataset in ('dvscifar10', 'dvsgesture')
            label = label.to(device)
            batch_size = label.numel()

            total_output = None
            batch_rates = None
            for t in range(T):
                # Clear per step so records hold only this step's outputs; the
                # last step's records stay available for inspection.
                mtor.clear_recorded_data()
                step_input = frame[t] if per_step else frame
                output = net(step_input)
                total_output = output.clone() if total_output is None else total_output + output
                step_rates = _spike_rates(mtor.records)
                batch_rates = step_rates if batch_rates is None else [a + b for a, b in zip(batch_rates, step_rates)]

            # Weight by batch size so a short final batch is not over-counted.
            weighted = [rate * batch_size for rate in batch_rates]
            rate_sums = weighted if rate_sums is None else [a + b for a, b in zip(rate_sums, weighted)]
            test_samples += batch_size
            correct += (total_output.argmax(1) == label).sum().item()
            functional.reset_net(net)

    if test_samples == 0:
        raise RuntimeError('the test set is empty')
    layer_rates = [s / (test_samples * T) for s in rate_sums]
    test_acc = correct / test_samples
    mean_rate = sum(layer_rates) / len(layer_rates) if layer_rates else float('nan')
    if not layer_rates:
        print('warning: no LIFNode outputs were recorded; firing rates are unavailable')
    print('test_acc:', test_acc)
    print(layer_rates)
    print('mean fire rate = ', mean_rate)
    return {'test_acc': test_acc, 'layer_rates': layer_rates, 'mean_rate': mean_rate}

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
